open Base
open OUnit2

(** What a value type must provide to be checked by [asrt]: an equality
    predicate (used as OUnit's [~cmp]) and a printer (used as OUnit's
    [~printer] when reporting a mismatch). *)
module type TESTABLE = sig
  type t

  (** [equal x y] is [true] when [x] and [y] should be considered the same
      value. *)
  val equal : t -> t -> bool

  (** Rendering of a value shown in assertion failure messages. *)
  val to_string : t -> string
end

(** Testable strings. Values are printed as OCaml string literals (quoted,
    with escapes via [%S]), so whitespace and control characters are
    visible in failure messages. *)
module String_ = struct
  type t = string

  let equal lhs rhs = String.equal lhs rhs

  let to_string = Printf.sprintf "%S"
end

(** Lifts a testable [X] to [X.t option]. Two options are equal when both
    are [None], or both are [Some] with [X.equal] payloads. [Some v] prints
    as ["Some(<X.to_string v>)"]. *)
module Option_ (X : TESTABLE) = struct
  type t = X.t option

  let equal lhs rhs =
    match lhs, rhs with
    | None, None -> true
    | Some l, Some r -> X.equal l r
    | Some _, None | None, Some _ -> false

  let to_string = function
    | None -> "None"
    | Some v -> Printf.sprintf "Some(%s)" (X.to_string v)
end

(** Lifts a testable [X] to ordered lists: element-wise [X.equal]
    comparison (order matters), printed as ["[a; b; c]"]. *)
module List_ (X : TESTABLE) = struct
  type t = X.t list

  let equal xs ys = List.equal ~equal:X.equal xs ys

  let to_string items =
    items
    |> List.map ~f:X.to_string
    |> String.concat ~sep:"; "
    |> Printf.sprintf "[%s]"
end

(** Lifts a testable [X] to lists compared as multisets. Order is ignored,
    but every element must occur the same number of times (under [X.equal])
    in both lists. Equality does a full count over both lists for each
    element, so it is quadratic in the total length. Printed as
    ["{a; b; c}"] in the order given. *)
module ListUnordered_ (X : TESTABLE) = struct
  type t = X.t list

  let equal xs ys =
    let occurrences l e = List.count l ~f:(X.equal e) in
    let balanced e = occurrences xs e = occurrences ys e in
    (* Checking xs then ys visits the same elements, in the same order, as a
       single pass over (xs @ ys) would. *)
    List.for_all xs ~f:balanced && List.for_all ys ~f:balanced

  let to_string items =
    items
    |> List.map ~f:X.to_string
    |> String.concat ~sep:"; "
    |> Printf.sprintf "{%s}"
end

(** Testable pairs. The first components are compared with [A.equal]; only
    if they match are the second components compared with [B.equal].
    Printed as ["(a, b)"]. *)
module Pair_ (A : TESTABLE) (B : TESTABLE) = struct
  type t = A.t * B.t

  let equal (l1, r1) (l2, r2) = A.equal l1 l2 && B.equal r1 r2

  let to_string (left, right) =
    Printf.sprintf "(%s, %s)" (A.to_string left) (B.to_string right)
end

(** Testable Base maps with printable keys [K] and testable values [V].
    Maps are compared with [Map.equal] using [V.equal] on the data. They are
    printed as ["{k1 := v1; k2 := v2}"], with bindings in [Map.to_alist]
    order. *)
module Map_
    (K : sig
      type t
      type comparator_witness
      val to_string : t -> string
    end)
    (V : TESTABLE) =
struct
  type t = (K.t, V.t, K.comparator_witness) Map.t

  let equal (lhs : t) (rhs : t) = Map.equal V.equal lhs rhs

  let to_string m =
    let binding_to_string (key, value) =
      Printf.sprintf "%s := %s" (K.to_string key) (V.to_string value)
    in
    Map.to_alist m
    |> List.map ~f:binding_to_string
    |> String.concat ~sep:"; "
    |> Printf.sprintf "{%s}"
end

(* Ready-made testables built from the functors defined in this file. *)
module StringList = List_ (String_)
module StringOption = Option_ (String_)
module IntOption = Option_ (Int)

(** [asrt (module X) expected actual] fails the current OUnit test unless
    [X.equal expected actual]. On failure, both values are rendered with
    [X.to_string]. *)
let asrt (type a) (module X : TESTABLE with type t = a) expected actual =
  assert_equal ~printer:X.to_string ~cmp:X.equal expected actual
(** [t name (module X) expected actual] is an OUnit test case labelled
    [name] that checks [expected] against [actual] with [asrt]. Both values
    are already evaluated when the test is built; only the comparison runs
    inside the test. *)
let t name m expected actual =
  let check _ctxt = asrt m expected actual in
  name >:: check
(** [ts (module X) cases] turns [(name, expected, actual)] triples into a
    single OUnit test list, building each case with [t]. *)
let ts m cases =
  cases
  |> List.map ~f:(fun (name, expected, actual) -> t name m expected actual)
  |> test_list

(** [log line] appends [line] to a buffer shared with [t_log].

    [t_log name expected func] is an OUnit test labelled [name]. When run, it
    clears the buffer, calls [func ()], and checks with [asrt] that the lines
    logged during the call equal [expected], in logging order.

    [func] runs when the test executes, not when the suite is built. An
    exception it raises is therefore reported as an error of this test
    rather than aborting suite construction. Because the buffer is a single
    shared ref, [t_log] tests must not run concurrently within one
    process. *)
let log, t_log =
  (* Lines are consed on as they arrive, so the buffer is newest-first and
     must be reversed before comparison. *)
  let out = ref [] in
  let log x = out := x :: !out in
  let t_log name expected func =
    name >:: fun _ ->
      out := [];
      func ();
      asrt (module StringList) expected (List.rev !out)
  in
  log, t_log

(** [logf fmt args...] formats like [Printf.sprintf] and hands the resulting
    line to [log]. [fmt] stays an explicit parameter so [logf] remains
    polymorphic in its format type. *)
let logf fmt = Printf.ksprintf (fun line -> log line) fmt

(** [tfail name e run] is an OUnit test labelled [name] that passes only if
    [run ()] raises an exception matching [e]. It fails with
    "no exception raised" if [run] returns normally.

    Exceptions are compared through their [Exn.sexp_of_t] form:
    - [(Name msg)]: the names must be equal, and [e]'s message must occur as
      a substring of the raised exception's message;
    - [(Name)] or a bare atom [Name]: the names must be equal. Sexplib0
      renders argument-less exceptions such as [Not_found] as a bare atom;
    - any other shape never matches. *)
let tfail name e run =
  let matches expected actual =
    match expected, actual with
    | Sexp.List [ Sexp.Atom e_exn; Sexp.Atom e_msg ],
      Sexp.List [ Sexp.Atom a_exn; Sexp.Atom a_msg ] ->
      (* The expected message only needs to appear somewhere in the actual
         one. *)
      String.(e_exn = a_exn && is_substring a_msg ~substring:e_msg)
    | Sexp.List [ Sexp.Atom e_exn ], Sexp.List [ Sexp.Atom a_exn ]
    | Sexp.Atom e_exn, Sexp.Atom a_exn ->
      String.(e_exn = a_exn)
    | _, _ -> false
  in
  name >:: fun _ ->
    match Result.try_with run with
    | Ok _ -> assert_failure "no exception raised"
    | Error a ->
      assert_equal ~cmp:matches ~printer:Sexp.to_string
        (Exn.sexp_of_t e) (Exn.sexp_of_t a)

(** [tfails cases] turns [(name, exn, run)] triples into a single OUnit test
    list, building each case with [tfail]. *)
let tfails cases =
  cases
  |> List.map ~f:(fun (name, exn, run) -> tfail name exn run)
  |> test_list
